{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d956e2b2-02ec-4750-a511-8c4f283add23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from transformers import pipeline\n",
    "import pycountry\n",
    "from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import pycountry\n",
    "import logging\n",
    "import country_converter as coco"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4322d202-e20f-4668-97e1-ade69ebe1a2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df=pd.read_csv(r'C:\\Users\\Yasaman\\Arab_spring_scholarly_attention\\Validating\\final_annotated.csv')\n",
    "df['Text']=df['Title']+' '+df['Abstract']\n",
    "df['union_annotation'] = df['union_annotation'].apply(lambda x: eval(x) if isinstance(x, str) else x)\n",
    "df['intersection_annotation'] = df['intersection_annotation'].apply(lambda x: eval(x) if isinstance(x, str) else x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "389cf6d4-1d70-4ee9-8a2c-478e4914cba3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "100%|██████████| 1000/1000 [11:17<00:00,  1.47it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load tokenizer and model\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"dslim/bert-base-NER\")\n",
    "model = AutoModelForTokenClassification.from_pretrained(\"dslim/bert-base-NER\")\n",
    "\n",
    "nlp = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
    "country_names = {country.name for country in pycountry.countries}\n",
    "\n",
    "# Function to extract locations and filter by country names\n",
    "def extract_and_filter_locations(text):\n",
    "    ner_results = nlp(text)\n",
    "    locations = [\n",
    "        entity[\"word\"]\n",
    "        for entity in ner_results\n",
    "        if ( entity[\"entity\"] == \"B-LOC\" or entity[\"entity\"] == \"I-LOC\") \n",
    "    ]\n",
    "\n",
    "    final=coco.convert(list(set(locations)), to='ISO3') \n",
    "    if type(final)==str:\n",
    "        final=[final]\n",
    "\n",
    "    final=[x.lower() for x in final if x!='not found']\n",
    "   \n",
    "    return final\n",
    "\n",
    "# Apply the function to the DataFrame with tqdm\n",
    "tqdm.pandas()  # Initialize tqdm for pandas\n",
    "\n",
    "coco_logger = coco.logging.getLogger()\n",
    "coco_logger.setLevel(logging.CRITICAL)\n",
    "df[\"locations\"] = df[\"Text\"].progress_apply(extract_and_filter_locations)\n",
    "\n",
    "# Display the result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d5bf378-5e1f-4e77-b78a-615e61f00ba6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv(r'C:\\Users\\Yasaman\\Arab_spring_scholarly_attention\\Validating\\bert_base_annotated.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "048f847d-de46-48fa-a86b-fc4438c66c2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall accuracy (union) 0.724\n",
      "overall accuracy (intersection) 0.71\n"
     ]
    }
   ],
   "source": [
    "print('overall accuracy (union)', sum(df['locations']==df['union_annotation'])/1000)\n",
    "print('overall accuracy (intersection)', sum(df['locations']==df['intersection_annotation'])/1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7396c1f3-a157-4ed3-95f0-0745b287d804",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Yasaman\\AppData\\Local\\Temp\\ipykernel_28204\\2343847047.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  sample_accuracies = df.groupby(\"SampleGroup\").apply(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SampleGroup</th>\n",
       "      <th>accuracy_union</th>\n",
       "      <th>accuracy_intersection</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>with_mention_arab</td>\n",
       "      <td>0.605000</td>\n",
       "      <td>0.570000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>with_mention</td>\n",
       "      <td>0.632143</td>\n",
       "      <td>0.582143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>field_20</td>\n",
       "      <td>0.896154</td>\n",
       "      <td>0.894231</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         SampleGroup  accuracy_union  accuracy_intersection\n",
       "0  with_mention_arab        0.605000               0.570000\n",
       "1       with_mention        0.632143               0.582143\n",
       "2           field_20        0.896154               0.894231"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_accuracies = df.groupby(\"SampleGroup\").apply(\n",
    "    lambda x: pd.Series({\n",
    "        \"accuracy_union\": (x.apply(lambda row: set(row['locations']) == set(row['union_annotation']), axis=1).mean()),\n",
    "        \"accuracy_intersection\": (x.apply(lambda row: set(row['locations']) == set(row['intersection_annotation']), axis=1).mean())\n",
    "    })\n",
    ").reset_index()\n",
    "\n",
    "sample_accuracies = sample_accuracies.sort_values(by='accuracy_union').reset_index(drop=True)\n",
    "sample_accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a7cc1929-71c5-4f2d-9da4-cee12fbc786c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Yasaman\\AppData\\Local\\Temp\\ipykernel_28204\\388812476.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  sample_accuracies = df.groupby(\"SampleGroup\").apply(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SampleGroup</th>\n",
       "      <th>jaccard_union</th>\n",
       "      <th>jaccard_intersection</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>with_mention</td>\n",
       "      <td>0.715000</td>\n",
       "      <td>0.668810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>with_mention_arab</td>\n",
       "      <td>0.795681</td>\n",
       "      <td>0.735680</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>field_20</td>\n",
       "      <td>0.908045</td>\n",
       "      <td>0.904519</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         SampleGroup  jaccard_union  jaccard_intersection\n",
       "0       with_mention       0.715000              0.668810\n",
       "1  with_mention_arab       0.795681              0.735680\n",
       "2           field_20       0.908045              0.904519"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_accuracies = df.groupby(\"SampleGroup\").apply(\n",
    "    lambda x: pd.Series({\n",
    "        \"jaccard_union\": (\n",
    "            x.apply(\n",
    "                lambda row: len(set(row['locations']).intersection(set(row['union_annotation']))) /\n",
    "                            len(set(row['locations']).union(set(row['union_annotation'])))\n",
    "                if len(set(row['locations']).union(set(row['union_annotation']))) > 0 else 1,\n",
    "                axis=1\n",
    "            ).mean()\n",
    "        ),\n",
    "        \"jaccard_intersection\": (\n",
    "            x.apply(\n",
    "                lambda row: len(set(row['locations']).intersection(set(row['intersection_annotation']))) /\n",
    "                            len(set(row['locations']).union(set(row['intersection_annotation'])))\n",
    "                if len(set(row['locations']).union(set(row['intersection_annotation']))) > 0 else 1,\n",
    "                axis=1\n",
    "            ).mean()\n",
    "        )\n",
    "    })\n",
    ").reset_index()\n",
    "\n",
    "sample_accuracies = sample_accuracies.sort_values(by='jaccard_union').reset_index(drop=True)\n",
    "sample_accuracies"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
